{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# From expectation maximization to stochastic variational inference\n",
    "\n",
    "**Update, Dec. 17<sup>th</sup> 2019**: This notebook is superseded by the following two notebooks:\n",
    "\n",
    "- [Latent variable models - part 1: Gaussian mixture models and the EM algorithm](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/master/latent_variable_models_part_1.ipynb)\n",
    "- [Latent variable models - part 2: Stochastic variational inference and variational autoencoders](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/master/latent_variable_models_part_2.ipynb).\n",
    "\n",
    "The following old variational autoencoder code<sup>[1]</sup> is still used in [other](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/master/variational_autoencoder_opt.ipynb) [notebooks](https://nbviewer.jupyter.org/github/krasserm/bayesian-machine-learning/blob/master/variational_autoencoder_dfc.ipynb) and kept here for further reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import cm\n",
    "\n",
    "import keras\n",
    "from keras import backend as K\n",
    "from keras import layers\n",
    "from keras.datasets import mnist\n",
    "from keras.models import Model, Sequential\n",
    "from keras.utils import to_categorical\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimensions of MNIST images  \n",
    "image_shape = (28, 28, 1)\n",
    "\n",
    "# Dimension of latent space\n",
    "latent_dim = 2\n",
    "\n",
    "# Mini-batch size for training\n",
    "batch_size = 128\n",
    "\n",
    "def create_encoder():\n",
    "    '''\n",
    "    Creates a convolutional encoder model for MNIST images.\n",
    "    \n",
    "    - Input for the created model are MNIST images. \n",
    "    - Output of the created model are the sufficient statistics\n",
    "      of the variational distriution q(t|x;phi), mean and log \n",
    "      variance. \n",
    "    '''\n",
    "    encoder_iput = layers.Input(shape=image_shape)\n",
    "    \n",
    "    x = layers.Conv2D(32, 3, padding='same', activation='relu')(encoder_iput)\n",
    "    x = layers.Conv2D(64, 3, padding='same', activation='relu', strides=(2, 2))(x)\n",
    "    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)\n",
    "    x = layers.Conv2D(64, 3, padding='same', activation='relu')(x)\n",
    "    x = layers.Flatten()(x)\n",
    "    x = layers.Dense(32, activation='relu')(x)\n",
    "\n",
    "    t_mean = layers.Dense(latent_dim)(x)\n",
    "    t_log_var = layers.Dense(latent_dim)(x)\n",
    "\n",
    "    return Model(encoder_iput, [t_mean, t_log_var], name='encoder')\n",
    "\n",
    "def create_decoder():\n",
    "    '''\n",
    "    Creates a (de-)convolutional decoder model for MNIST images.\n",
    "    \n",
    "    - Input for the created model are latent vectors t.\n",
    "    - Output of the model are images of shape (28, 28, 1) where\n",
    "      the value of each pixel is the probability of being white.\n",
    "    '''\n",
    "    decoder_input = layers.Input(shape=(latent_dim,))\n",
    "    \n",
    "    x = layers.Dense(12544, activation='relu')(decoder_input)\n",
    "    x = layers.Reshape((14, 14, 64))(x)\n",
    "    x = layers.Conv2DTranspose(32, 3, padding='same', activation='relu', strides=(2, 2))(x)\n",
    "    x = layers.Conv2D(1, 3, padding='same', activation='sigmoid')(x)\n",
    "    \n",
    "    return Model(decoder_input, x, name='decoder')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample(args):\n",
    "    '''\n",
    "    Draws samples from a standard normal and scales the samples with\n",
    "    standard deviation of the variational distribution and shifts them\n",
    "    by the mean.\n",
    "    \n",
    "    Args:\n",
    "        args: sufficient statistics of the variational distribution.\n",
    "        \n",
    "    Returns:\n",
    "        Samples from the variational distribution.\n",
    "    '''\n",
    "    t_mean, t_log_var = args\n",
    "    t_sigma = K.sqrt(K.exp(t_log_var))\n",
    "    epsilon = K.random_normal(shape=K.shape(t_mean), mean=0., stddev=1.)\n",
    "    return t_mean + t_sigma * epsilon\n",
    "\n",
    "def create_sampler():\n",
    "    '''\n",
    "    Creates a sampling layer.\n",
    "    '''\n",
    "    return layers.Lambda(sample, name='sampler')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoder = create_encoder()\n",
    "decoder = create_decoder()\n",
    "sampler = create_sampler()\n",
    "\n",
    "x = layers.Input(shape=image_shape)\n",
    "t_mean, t_log_var = encoder(x)\n",
    "t = sampler([t_mean, t_log_var])\n",
    "t_decoded = decoder(t)\n",
    "\n",
    "vae = Model(x, t_decoded, name='vae')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def neg_variational_lower_bound(x, t_decoded):\n",
    "    '''\n",
    "    Negative variational lower bound used as loss function\n",
    "    for training the variational autoencoder.\n",
    "    \n",
    "    Args:\n",
    "        x: input images\n",
    "        t_decoded: reconstructed images\n",
    "    '''\n",
    "    # Reconstruction loss\n",
    "    rc_loss = K.sum(K.binary_crossentropy(\n",
    "        K.batch_flatten(x), \n",
    "        K.batch_flatten(t_decoded)), axis=-1)\n",
    "\n",
    "    # Regularization term (KL divergence)\n",
    "    kl_loss = -0.5 * K.sum(1 + t_log_var \\\n",
    "                             - K.square(t_mean) \\\n",
    "                             - K.exp(t_log_var), axis=-1)\n",
    "    \n",
    "    # Average over mini-batch\n",
    "    return K.mean(rc_loss + kl_loss)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "colab": {
     "autoexec": {
      "startup": false,
      "wait_interval": 0
     },
     "base_uri": "https://localhost:8080/",
     "height": 901,
     "output_extras": [
      {
       "item_id": 44
      },
      {
       "item_id": 45
      }
     ]
    },
    "colab_type": "code",
    "executionInfo": {
     "elapsed": 469640,
     "status": "ok",
     "timestamp": 1522825755182,
     "user": {
      "displayName": "Martin Krasser",
      "photoUrl": "//lh6.googleusercontent.com/-zVZZRiAWOs4/AAAAAAAAAAI/AAAAAAAAAlk/Q2XGRf45rYM/s50-c-k-no/photo.jpg",
      "userId": "115420131270379583938"
     },
     "user_tz": -120
    },
    "id": "IcArhIgEbNoU",
    "outputId": "7fdb64c3-6c04-4b6d-bc88-32a66c86e467"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train on 60000 samples, validate on 10000 samples\n",
      "Epoch 1/25\n",
      " - 20s - loss: 195.2594 - val_loss: 174.2528\n",
      "Epoch 2/25\n",
      " - 19s - loss: 171.3044 - val_loss: 166.7014\n",
      "Epoch 3/25\n",
      " - 19s - loss: 164.3512 - val_loss: 161.6999\n",
      "Epoch 4/25\n",
      " - 19s - loss: 160.1977 - val_loss: 161.6477\n",
      "Epoch 5/25\n",
      " - 19s - loss: 157.7647 - val_loss: 156.3657\n",
      "Epoch 6/25\n",
      " - 19s - loss: 156.1030 - val_loss: 155.2672\n",
      "Epoch 7/25\n",
      " - 19s - loss: 154.8048 - val_loss: 155.3826\n",
      "Epoch 8/25\n",
      " - 19s - loss: 153.7678 - val_loss: 152.8838\n",
      "Epoch 9/25\n",
      " - 19s - loss: 152.8601 - val_loss: 153.5024\n",
      "Epoch 10/25\n",
      " - 19s - loss: 152.0742 - val_loss: 152.3024\n",
      "Epoch 11/25\n",
      " - 19s - loss: 151.3636 - val_loss: 152.1116\n",
      "Epoch 12/25\n",
      " - 19s - loss: 150.7016 - val_loss: 152.2413\n",
      "Epoch 13/25\n",
      " - 19s - loss: 150.1563 - val_loss: 151.3501\n",
      "Epoch 14/25\n",
      " - 19s - loss: 149.6736 - val_loss: 149.6158\n",
      "Epoch 15/25\n",
      " - 19s - loss: 149.1454 - val_loss: 149.2047\n",
      "Epoch 16/25\n",
      " - 19s - loss: 148.7548 - val_loss: 149.0340\n",
      "Epoch 17/25\n",
      " - 19s - loss: 148.3118 - val_loss: 148.8142\n",
      "Epoch 18/25\n",
      " - 19s - loss: 147.9456 - val_loss: 149.1827\n",
      "Epoch 19/25\n",
      " - 19s - loss: 147.6405 - val_loss: 151.5525\n",
      "Epoch 20/25\n",
      " - 19s - loss: 147.2038 - val_loss: 148.2406\n",
      "Epoch 21/25\n",
      " - 19s - loss: 146.8990 - val_loss: 150.3239\n",
      "Epoch 22/25\n",
      " - 19s - loss: 146.6422 - val_loss: 147.5277\n",
      "Epoch 23/25\n",
      " - 18s - loss: 146.3777 - val_loss: 148.2235\n",
      "Epoch 24/25\n",
      " - 19s - loss: 146.0520 - val_loss: 147.7608\n",
      "Epoch 25/25\n",
      " - 19s - loss: 145.8438 - val_loss: 147.9813\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<keras.callbacks.History at 0x7fbaf4e40da0>"
      ]
     },
     "execution_count": 6,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# MNIST training and validation data\n",
    "(x_train, _), (x_test, _) = mnist.load_data()\n",
    "\n",
    "x_train = x_train.astype('float32') / 255.\n",
    "x_train = x_train.reshape(x_train.shape + (1,))\n",
    "\n",
    "x_test = x_test.astype('float32') / 255.\n",
    "x_test = x_test.reshape(x_test.shape + (1,))\n",
    "\n",
    "# Compile variational autoencoder model\n",
    "vae.compile(optimizer='rmsprop', loss=neg_variational_lower_bound)\n",
    "\n",
    "# Train variational autoencoder with MNIST images\n",
    "vae.fit(x=x_train, \n",
    "         y=x_train,\n",
    "         epochs=25,\n",
    "         shuffle=True,\n",
    "         batch_size=batch_size,\n",
    "         validation_data=(x_test, x_test), verbose=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "\\[1\\] François Chollet. [Deep Learning with Python](https://www.manning.com/books/deep-learning-with-python).  "
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
